{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c50e24b1-af2d-4098-82cb-056e5560a78f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:52:04.229194Z",
     "iopub.status.busy": "2024-01-09T05:52:04.228352Z",
     "iopub.status.idle": "2024-01-09T05:52:04.233790Z",
     "shell.execute_reply": "2024-01-09T05:52:04.233085Z",
     "shell.execute_reply.started": "2024-01-09T05:52:04.229159Z"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from llama_index import SimpleDirectoryReader\n",
    "from llama_index.node_parser import SentenceSplitter\n",
    "from llama_index.schema import MetadataMode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ca4963ea-966b-49ac-b8b5-17689b4ec0a3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:52:07.163714Z",
     "iopub.status.busy": "2024-01-09T05:52:07.163210Z",
     "iopub.status.idle": "2024-01-09T05:52:07.167490Z",
     "shell.execute_reply": "2024-01-09T05:52:07.166800Z",
     "shell.execute_reply.started": "2024-01-09T05:52:07.163683Z"
    }
   },
   "outputs": [],
   "source": [
    "# Plain-text inputs for each split; load_corpus() forwards these lists to SimpleDirectoryReader.\n",
    "TRAIN_FILES = [\"train.txt\"]\n",
    "VAL_FILES = [\"test.txt\"]\n",
    "\n",
    "# Presumably where the parsed corpora would be saved; no cell in this notebook references them.\n",
    "TRAIN_CORPUS_FPATH, VAL_CORPUS_FPATH = \"train_corpus.json\", \"val_corpus.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "be5d9621-2e82-496c-b3e6-64c403c51f60",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:52:07.862139Z",
     "iopub.status.busy": "2024-01-09T05:52:07.860880Z",
     "iopub.status.idle": "2024-01-09T05:52:07.869680Z",
     "shell.execute_reply": "2024-01-09T05:52:07.868223Z",
     "shell.execute_reply.started": "2024-01-09T05:52:07.862096Z"
    }
   },
   "outputs": [],
   "source": [
    "def load_corpus(files, verbose=False, chunk_size=250, chunk_overlap=0):\n",
    "    \"\"\"Load `files` and split the resulting documents into nodes with `SentenceSplitter`.\n",
    "\n",
    "    Args:\n",
    "        files: Paths handed to `SimpleDirectoryReader(input_files=...)`.\n",
    "        verbose: When true, print progress messages and show the parser's\n",
    "            progress bar.\n",
    "        chunk_size: `chunk_size` passed to `SentenceSplitter`. Defaults to 250,\n",
    "            the value that used to be hard-coded here.\n",
    "        chunk_overlap: `chunk_overlap` passed to `SentenceSplitter`. Defaults to 0,\n",
    "            the value that used to be hard-coded here.\n",
    "\n",
    "    Returns:\n",
    "        The nodes returned by `SentenceSplitter.get_nodes_from_documents`.\n",
    "    \"\"\"\n",
    "    if verbose:\n",
    "        print(f\"Loading files {files}\")\n",
    "\n",
    "    docs = SimpleDirectoryReader(input_files=files).load_data()\n",
    "    if verbose:\n",
    "        print(f\"Loaded {len(docs)} docs\")\n",
    "\n",
    "    # Chunking is configurable so train/val corpora can be re-cut without editing this function.\n",
    "    parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
    "    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
    "    if verbose:\n",
    "        print(f\"Parsed {len(nodes)} nodes\")\n",
    "\n",
    "    return nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dea8767a-69f2-4c06-847b-464127872237",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:52:11.452101Z",
     "iopub.status.busy": "2024-01-09T05:52:11.451493Z",
     "iopub.status.idle": "2024-01-09T05:52:11.862828Z",
     "shell.execute_reply": "2024-01-09T05:52:11.862441Z",
     "shell.execute_reply.started": "2024-01-09T05:52:11.452067Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading files ['train.txt']\n",
      "Loaded 1 docs\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d49144fa9bdd4586957406e8a4c8633b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parsed 129 nodes\n",
      "Loading files ['test.txt']\n",
      "Loaded 1 docs\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a17bb71714c412ead92e6010a2182c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parsed 107 nodes\n"
     ]
    }
   ],
   "source": [
    "# Parse both splits the same way; the generator runs train first, then validation.\n",
    "train_nodes, val_nodes = (\n",
    "    load_corpus(split_files, verbose=True)\n",
    "    for split_files in (TRAIN_FILES, VAL_FILES)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5382543f-430b-446a-b1a4-3f121b0177b0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:52:17.389122Z",
     "iopub.status.busy": "2024-01-09T05:52:17.388703Z",
     "iopub.status.idle": "2024-01-09T05:52:17.394748Z",
     "shell.execute_reply": "2024-01-09T05:52:17.394249Z",
     "shell.execute_reply.started": "2024-01-09T05:52:17.389096Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[TextNode(id_='065c7c68-64f1-41e9-9b5f-6d8141aae864', embedding=None, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='008c3477-fbe1-4da1-86a9-91d83316333d', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='77b3142f61c86cad975ca9bc682650512f3a0498a97fb38e6a5b3721324a80c7'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='ff466b80-aee4-4e14-9aa3-8becdbaa3a88', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='c2950b491d0515bb6385ef1831baaa2bd4e848f9e12a831e31bdccf00200172f')}, hash='5568080c7f77966aa8e31768c5ef75d877168f8501d8adf530b6db5d72886096', text='受半导体行业周期“磨底”、消费电子市场需求恢复缓慢等影响，今年A股半导体行业上市公司半年度业绩预告显示，归母净利润普遍同比下滑，IC设计、封测等环节成为“重灾区”， 。环比来看，部分头部企业第二季度业绩已经企稳复苏，盈利环比增长，人工智能、汽车电子、电网等板块贡献业绩，有公司表示下半年将企稳增长。', start_char_idx=0, end_char_idx=149, text_template='{metadata_str}\\n\\n{content}', metadata_template='{key}: {value}', metadata_seperator='\\n'),\n",
       " TextNode(id_='ff466b80-aee4-4e14-9aa3-8becdbaa3a88', embedding=None, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='008c3477-fbe1-4da1-86a9-91d83316333d', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='77b3142f61c86cad975ca9bc682650512f3a0498a97fb38e6a5b3721324a80c7'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='065c7c68-64f1-41e9-9b5f-6d8141aae864', node_type=<ObjectType.TEXT: '1'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='5568080c7f77966aa8e31768c5ef75d877168f8501d8adf530b6db5d72886096'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='9b42373d-7002-4cbe-b7ea-ea855696124e', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='136ce81b42b41669ca89fa26ec3d4adb9158e83b6c80cc61ab7cec118a83007d')}, hash='c2950b491d0515bb6385ef1831baaa2bd4e848f9e12a831e31bdccf00200172f', text='据Choice金融终端统计，目前超过30家半导体上市公司披露业绩预告，其中，通富微电、汇顶科技、士兰微、上海贝岭、中晶科技、大为股份等公司业绩预计首亏，博通集成预亏增加，韦尔股份、瑞芯微、华天科技等公司最大降幅超过90%。相比之下，北方华创、中微公司等头部企业翻倍增长。\\n\\n\\u3000\\u3000设计企业：', start_char_idx=153, end_char_idx=297, text_template='{metadata_str}\\n\\n{content}', metadata_template='{key}: {value}', 
metadata_seperator='\\n'),\n",
       " TextNode(id_='9b42373d-7002-4cbe-b7ea-ea855696124e', embedding=None, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='008c3477-fbe1-4da1-86a9-91d83316333d', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='77b3142f61c86cad975ca9bc682650512f3a0498a97fb38e6a5b3721324a80c7'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='ff466b80-aee4-4e14-9aa3-8becdbaa3a88', node_type=<ObjectType.TEXT: '1'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='c2950b491d0515bb6385ef1831baaa2bd4e848f9e12a831e31bdccf00200172f'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='e7dbb434-48ba-4289-a3c6-d0369515dd23', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='d853cc49bf966d5bf25eae4174a6910b48855e7c516134121075537cbb9f8db9')}, hash='136ce81b42b41669ca89fa26ec3d4adb9158e83b6c80cc61ab7cec118a83007d', text='加速去库存\\n\\n\\u3000\\u3000由于终端消费电子市场低迷，芯片设计企业上半年业绩同比普遍预降，但随着去库存推进，部分企业业绩触底企稳，并在二季度环比增长。\\n\\n\\u3000\\u3000作为AIot（人工智能与物联网）芯片龙头，瑞芯微预计今年上半年实现营业收入约8.58亿元，同比减少约31%，归母净利润2000万元到3000万元，同比减少93%到89%。环比来看，第二季度公司营收增长约六成，归母净利润环比实现扭亏。', start_char_idx=302, end_char_idx=492, 
text_template='{metadata_str}\\n\\n{content}', metadata_template='{key}: {value}', metadata_seperator='\\n')]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_nodes[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "70f02876-0e48-49dc-bfa8-8853b6e6651f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-01-09T05:51:46.252811Z",
     "iopub.status.busy": "2024-01-09T05:51:46.252196Z",
     "iopub.status.idle": "2024-01-09T05:51:49.289481Z",
     "shell.execute_reply": "2024-01-09T05:51:49.289154Z",
     "shell.execute_reply.started": "2024-01-09T05:51:46.252775Z"
    }
   },
   "outputs": [],
   "source": [
    "from llama_index.finetuning import (\n",
    "    generate_qa_embedding_pairs,\n",
    "    EmbeddingQAFinetuneDataset,\n",
    ")\n",
    "from llama_index.llms import OpenAI\n",
    "import getpass\n",
    "import os\n",
    "\n",
    "# Never store a key in the notebook: reuse OPENAI_API_KEY when it is already\n",
    "# exported, and only prompt (without echoing the input) when it is missing.\n",
    "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
    "    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")\n",
    "\n",
    "# LLM used below to generate the synthetic questions for each node.\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "160e44d2-29e0-4303-85a0-3fb1102b5074",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 129/129 [08:03<00:00,  3.75s/it]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 107/107 [06:49<00:00,  3.83s/it]\n"
     ]
    }
   ],
   "source": [
    "qa_generate_prompt_tmpl = \"\"\"\\\n",
    "Context information is below.\n",
    "\n",
    "---------------------\n",
    "{context_str}\n",
    "---------------------\n",
    "\n",
    "Given the context information and not prior knowledge.\n",
    "generate only questions based on the below query.\n",
    "\n",
    "You are a Professor. Your task is to setup \\\n",
    "{num_questions_per_chunk} questions for an upcoming \\\n",
    "quiz/examination in Chinese. The questions should be diverse in nature \\\n",
    "across the document in Chinese. The questions should not contain options, not start with Q1/ Q2. \\\n",
    "Restrict the questions to the context information provided.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def load_or_generate_qa_dataset(nodes, dataset_fpath):\n",
    "    \"\"\"Return the QA dataset stored at `dataset_fpath`, generating it on a cache miss.\n",
    "\n",
    "    Generation calls the LLM and took several minutes per split in the recorded\n",
    "    run, so the result is saved to `dataset_fpath` right away and reused by\n",
    "    later runs. Delete the file to force regeneration (e.g. after changing the\n",
    "    prompt or the chunking).\n",
    "\n",
    "    Args:\n",
    "        nodes: Nodes passed to `generate_qa_embedding_pairs`.\n",
    "        dataset_fpath: JSON file used both as cache and as output path.\n",
    "\n",
    "    Returns:\n",
    "        The dataset loaded via `EmbeddingQAFinetuneDataset.from_json`, or the\n",
    "        freshly generated one.\n",
    "    \"\"\"\n",
    "    if os.path.exists(dataset_fpath):\n",
    "        return EmbeddingQAFinetuneDataset.from_json(dataset_fpath)\n",
    "\n",
    "    dataset = generate_qa_embedding_pairs(\n",
    "        nodes=nodes,\n",
    "        llm=llm,\n",
    "        num_questions_per_chunk=1,\n",
    "        qa_generate_prompt_tmpl=qa_generate_prompt_tmpl,\n",
    "    )\n",
    "    # Save immediately so a failure on the next split does not discard this one.\n",
    "    dataset.save_json(dataset_fpath)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "train_dataset = load_or_generate_qa_dataset(train_nodes, \"train_dataset.json\")\n",
    "val_dataset = load_or_generate_qa_dataset(val_nodes, \"val_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b0e1401d-6a5d-45d9-a980-0c129ba122a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import SentenceTransformersFinetuneEngine\n",
    "\n",
    "# Machine-specific locations: the checkpoint passed as model_id and the\n",
    "# destination passed as model_output_path. Edit these for your environment.\n",
    "base_model_id = \"/data-xgb1/lmj/models/bge-base-zh-v1.5\"\n",
    "finetuned_model_path = \"/data-xgb1/lmj/models/bge-base-ft-001\"\n",
    "\n",
    "finetune_engine = SentenceTransformersFinetuneEngine(\n",
    "    train_dataset,\n",
    "    model_id=base_model_id,\n",
    "    model_output_path=finetuned_model_path,\n",
    "    val_dataset=val_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b50495a7-50a8-4adf-93e2-854f07098e04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e8718afd38b4b7d8a0c0837f6a999f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51005d1caae34aec9c91aa4b01d8089e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/67 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "180600ef9f3a437fb35b563f44d23557",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/67 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "2a289e1c-6660-4d4d-8c40-5f8587a21154",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultipleNegativesRankingLoss(\n",
       "  (model): SentenceTransformer(\n",
       "    (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel \n",
       "    (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       "    (2): Normalize()\n",
       "  )\n",
       "  (cross_entropy_loss): CrossEntropyLoss()\n",
       ")"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41faed90-f7cb-4f84-9b92-5f8bbf2491a0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
